import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, random_split
from dataset.LDataSet import LDataSet
from model.LCNN import LCNN
from torch import nn

# Use seaborn's "dark" style for any plots drawn while exploring the data.
sns.set_style("dark")


def readData(path_info, path_label):
    """Load the raw info/label Excel files and normalise their column names.

    Args:
        path_info: path to the per-user, per-day readings workbook.
        path_label: path to the per-user label workbook.

    Returns:
        (train_info, train_label) DataFrames with English column names:
        train_info -> ['user_id', 'date', 'sum', 'max', 'mean', 'min'],
        train_label -> ['user_id', 'user_nature', 'error'].
    """
    train_info = pd.read_excel(path_info)
    train_label = pd.read_excel(path_label)

    # Replace the original headers with short English names so the rest of
    # the pipeline can address columns uniformly.
    train_info.columns = ['user_id', 'date', 'sum', 'max', 'mean', 'min']
    train_label.columns = ['user_id', 'user_nature', 'error']

    return train_info, train_label


# Normalise a slash-separated date string and convert it to an int.
def getDate(value):
    """Convert a 'YYYY/M/D' date string into the int YYYYMMDD.

    Single-digit month/day parts are zero-padded, e.g. '2021/3/1' -> 20210301.
    """
    parts = value.split("/")
    return int(parts[0] + parts[1].zfill(2) + parts[2].zfill(2))


def sortList(list_res):
    """Sort per-day records ascending by date (element 0), in place.

    The original hand-rolled bubble sort was O(n^2); `list.sort` is also
    stable and sorts in place, so callers that rely on the in-place
    mutation keep working.

    Args:
        list_res: list of rows whose first element is an int date (YYYYMMDD).

    Returns:
        The same (now sorted) list object, for caller convenience.
    """
    list_res.sort(key=lambda record: record[0])
    return list_res


# Collect every distinct reading date across all users, ascending.
def get_all_date(path_info, path_label):
    """Return all distinct dates in the info table as sorted ints (YYYYMMDD).

    `path_label` is read only because `readData` loads both workbooks
    together; the label table itself is not used here.
    """
    train_info, _train_label = readData(path_info, path_label)
    # sorted() over a generator replaces the original append loop + .sort().
    return sorted(getDate(date) for date in train_info['date'].unique())


def getData(path_info, path_label):
    """Load the raw tables and reshape them into per-user day sequences.

    Args:
        path_info: path to the readings workbook (see readData).
        path_label: path to the labels workbook.

    Returns:
        data_train: list with one entry per user; each entry is a list of
            [date_int, sum, max, mean, min, user_nature] rows sorted by date.
        data_train_label: list of 0/1 anomaly flags, in train_label row order
            (assumed to match the user order of data_train — TODO confirm).
    """
    train_info, train_label = readData(path_info, path_label)

    # Encode the Chinese user-nature categories as integers.
    train_label.loc[train_label['user_nature'] == "低压居民", "user_nature"] = 0
    train_label.loc[train_label['user_nature'] == "低压非居民", "user_nature"] = 1
    train_label.loc[train_label['user_nature'] == "高压", "user_nature"] = 2

    # Encode the anomaly flag as 0/1.
    train_label.loc[train_label['error'] == "无异常", "error"] = 0
    train_label.loc[train_label['error'] == "有异常", "error"] = 1

    # Join the readings with each user's nature; the label is returned
    # separately, so drop the error column from the merged frame.
    data_train_info = pd.merge(train_info, train_label, on="user_id")
    data_train_info.drop("error", axis=1, inplace=True)

    # Reshape into a 3-D structure: users -> days -> feature row.
    ids = data_train_info['user_id'].unique().tolist()
    data_train = []
    for user in ids:
        # Non-inplace drop: the original `inplace=True` drop on a `.loc`
        # slice raised SettingWithCopyWarning and relied on pandas copy
        # semantics. (Also: the original shadowed the builtin `list`.)
        temp = data_train_info.loc[data_train_info['user_id'] == user].drop('user_id', axis=1)
        rows = []
        for i in range(len(temp)):
            # Column 0 is the date string; convert it to an int YYYYMMDD.
            row = [getDate(temp.iloc[i, 0])]
            for j in range(1, 6):
                row.append(temp.iloc[i, j])
            rows.append(row)
        data_train.append(sortList(rows))

    # One 0/1 anomaly flag per user.
    data_train_label = train_label['error'].tolist()

    return data_train, data_train_label


# NOTE(review): date 20210311 only exists for one user (index 77, i.e. i=77),
# so rows with that date are deleted; user index 165 is additionally missing
# 20210309/20210310 and is patched explicitly at the end. Both special cases
# are hard-wired to this particular dataset.
# Filling strategy: every reading field of a missing date is padded with 0
# (user_nature is copied from one of the user's existing rows).
def data_fill(data_train, path_info, path_label):
    """Pad each user's day list so every user covers the full set of dates.

    Args:
        data_train: output of getData — list of users, each a list of
            [date_int, sum, max, mean, min, user_nature] rows.
        path_info, path_label: Excel paths, forwarded to get_all_date.

    Returns:
        The same data_train list, mutated in place: missing dates appended as
        zero-filled rows, then each user's rows re-sorted ascending by date.

    NOTE(review): the inner loop appends to / deletes from data_train[user]
    while iterating a range(len(...)) captured beforehand; it only works
    because appended rows land past the scan window and get ordered by the
    final sortList call. Treat the exact statement order as load-bearing.
    """
    # print(data_train)
    date_list = get_all_date(path_info, path_label)

    # Walk every user's rows.
    for user in range(len(data_train)):
        # Compare the user's (~70 days of) readings against the full date list.
        step = 0  # index into date_list (the full, sorted date list)
        for day in range(len(data_train[user])):
            # Three possibilities at each position: the user's date equals the
            # list date; the user's date is greater (dates were skipped); or
            # the user's sequence ends before date_list does.
            if data_train[user][day][0] > date_list[step]:
                # Append zero-filled rows until the user's current date
                # matches date_list[step]; user_nature (index 5) is copied.
                while data_train[user][day][0] != date_list[step]:
                    temp_data = [date_list[step], 0, 0, 0, 0, data_train[user][day][5]]
                    data_train[user].append(temp_data)
                    step += 1
            # Special case: only a single user has 20210311 — drop that row.
            elif data_train[user][day][0] == 20210311:
                del data_train[user][day]
            # The trailing case (user's last date < current date_list[step])
            # cannot be handled inside this fixed-length loop; see the
            # explicit patch for user 165 below.
            # elif user == 165 and data_train[user][day][0] < date_list[step]:
            #     print(111111111)
            #     while step < len(date_list):
            #         temp_data = [date_list[step], 0, 0, 0, 0, data_train[user][day][5]]
            #         data_train[user].append(temp_data)
            #         step += 1
            #         day += 1
            step += 1
        # After filling, restore ascending date order.
        data_train[user] = sortList(data_train[user])

    # Only user 165 hits the trailing-gap case, so it is patched by hand
    # instead of being handled generically in the loop above.
    temp_data1 = [20210309, 0, 0, 0, 0, data_train[165][0][5]]
    temp_data2 = [20210310, 0, 0, 0, 0, data_train[165][0][5]]
    data_train[165].append(temp_data1)
    data_train[165].append(temp_data2)

    return data_train


# Standardise each user's matrix and convert it to a batched tensor.
def standardize(data_train):
    """Z-score each user's [days, 6] matrix and wrap it as a [1, days, 6] tensor.

    Each user is scaled independently (a fresh StandardScaler per user).
    The input list is mutated in place and also returned.
    """
    for idx, sample in enumerate(data_train):
        scaler = StandardScaler(copy=True, with_mean=True, with_std=True)
        scaled = scaler.fit_transform(sample)
        # unsqueeze(0) adds the leading batch axis: [days, 6] -> [1, days, 6].
        data_train[idx] = torch.Tensor(scaled).unsqueeze(0)

    return data_train


def get_train_and_test_dataset(path_info, path_label):
    """Build the full preprocessed dataset and split it 70/30 at random.

    Returns:
        (train_dataset, test_dataset) — torch Subsets of an LDataSet built
        from the standardized per-user tensors and their 0/1 labels.
    """
    features, labels = getData(path_info, path_label)
    features = data_fill(features, path_info, path_label)
    features = standardize(features)
    full_dataset = LDataSet(features, labels)

    # 70% train, remainder test (handles rounding of the 0.7 split).
    n_train = int(len(full_dataset) * 0.7)
    train_part, test_part = random_split(full_dataset, [n_train, len(full_dataset) - n_train])
    return train_part, test_part


def print_label_ratios(dataset, dataset_name):
    """Print the share of abnormal (label 1) and normal (label 0) samples.

    Args:
        dataset: sized iterable of (info, label) pairs; labels must be 0 or 1.
        dataset_name: name used in the printed report, e.g. "训练集".
    """
    num_error = 0
    num_normal = 0
    for info, label in dataset:
        if label == 0:
            num_normal += 1
        elif label == 1:
            num_error += 1
        else:
            # Should not happen: getData encodes the error column as 0/1 only.
            print("label 既不是0 也不是1！")
    print(f"异常用户在{dataset_name}中的占比：{num_error / len(dataset)}")
    print(f"正常用户在{dataset_name}中的占比：{num_normal / len(dataset)}")


if __name__ == "__main__":
    path_info = "../dataset/train_info.xlsx"
    path_label = "../dataset/train_label.xlsx"
    train_dataset, test_dataset = get_train_and_test_dataset(path_info, path_label)

    # Report the class balance of each split so a badly skewed random split
    # is immediately visible. (The original duplicated this counting loop
    # verbatim for train and test; it now lives in print_label_ratios.
    # The large block of commented-out training/debug scratch code that
    # followed has been removed.)
    print_label_ratios(train_dataset, "训练集")
    print_label_ratios(test_dataset, "测试集")